Gaussian Mixture Model for Unsupervised Galaxy Classification

In the present notebook we introduce the Gaussian Mixture Model (GMM) analysis of Souza et al. 2017, carried out here with the Python scikit-learn library.

_Notes_

  • The results presented in the paper were obtained with R, so small differences from them can be expected here.
  • Most figure styling functions are defined in a separate file ```gmm_fig_style.py``` for readability.

Global Picture and Purpose

Why Classify Objects?

The process of classifying objects plays a crucial role in science. By grouping objects based on their similarities, we expect to expose the underlying physical processes responsible for the classification. Nonetheless, human-made classifications often rely on empirical boundaries, which are then subject to much debate.
This sometimes very tedious task has recently benefited from technological and theoretical developments. Indeed, the rise of machine learning over the last few years, and especially of unsupervised clustering techniques, has opened up new possibilities for finding similarities and patterns within a data set.

BPT and WHAN diagrams

In astronomy the best-known classification scheme is probably the Hubble tuning fork, which sorts galaxies by their morphology. This classification helped astronomers better understand the origin, evolution and dynamics of galaxies. However, it has also proved limiting in some cases, which has prompted scientists to use additional classification schemes.
For instance, to distinguish star-forming (SF) galaxies from galaxies hosting an active galactic nucleus (AGN), such as Seyferts, astronomers measure emission-line ratios and compare them in the BPT and WHAN diagrams. Both diagrams share the horizontal axis $\log([\mathrm{NII}]/\mathrm{H}_{\alpha})$; for the vertical axis, the BPT diagram uses $\log([\mathrm{OIII}]/\mathrm{H}_{\beta})$ while the WHAN diagram uses $\log(W_{\mathrm{H}_{\alpha}})$, the logarithm of the H$\alpha$ equivalent width.

Our purpose hereafter is to compare an unsupervised classification obtained with a Gaussian Mixture Model (GMM) to the existing BPT and WHAN classifications. The GMM clustering uses all 83578 galaxies in the dataset, in three dimensions: $\log(\mathrm{[NII]}/\mathrm{H}_{\alpha})$, $\log(\mathrm{[OIII]}/\mathrm{H}_{\beta})$ and $\log(W_{\mathrm{H}_{\alpha}})$.
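For reference, the model fitted in this notebook can be written as below. The notation is ours and not taken from the paper: each galaxy $i$ is described by the three features just listed, $K$ is the number of clusters, the weights $\pi_k$ are non-negative and sum to one, and each component has its own full covariance matrix $\boldsymbol{\Sigma}_k$ (matching `covariance_type='full'` in the code).

$$
\mathbf{x}_i = \Big(\log(\mathrm{[NII]}/\mathrm{H}_{\alpha}),\ \log(\mathrm{[OIII]}/\mathrm{H}_{\beta}),\ \log(W_{\mathrm{H}_{\alpha}})\Big)_i,
\qquad
p(\mathbf{x}_i) = \sum_{k=1}^{K} \pi_k\, \mathcal{N}(\mathbf{x}_i \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)
$$

$$
P(k \mid \mathbf{x}_i) = \frac{\pi_k\, \mathcal{N}(\mathbf{x}_i \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)}{\sum_{j=1}^{K} \pi_j\, \mathcal{N}(\mathbf{x}_i \mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j)}
$$

The second expression is the probability that galaxy $i$ belongs to group $k$, which is the quantity scikit-learn's `predict_proba` returns for a fitted model.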

Importing and Visualizing the Data


In [1]:
# First of all, let's import some useful libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import itertools
from sklearn.mixture import GaussianMixture
from gmm_fig_style import *

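# The next lines force the use
# of a specific figure configuration and colormap
try:
    plt.style.use('seaborn-whitegrid')
except OSError:
    # matplotlib >= 3.6 renamed the seaborn styles (old names removed in 3.8)
    plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette(sns.color_palette('Set2', 4))
# one discrete color per GMM group (4 groups at most)
my_cmap = mpl.colors.ListedColormap(sns.color_palette('Set2', 4), name='from_list')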

print('Color palette used throughout the notebook :')
sns.palplot(sns.color_palette("Set2",4))
plt.show()


Color palette used throughout the notebook :

In [2]:
# reading the catalogue
data = pd.read_csv('Catalogue/GMM_catalogue.csv')

In [3]:
# display some information
data.info()


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 83578 entries, 0 to 83577
Data columns (total 13 columns):
plate          83578 non-null int64
mjd            83578 non-null int64
fiber_id       83578 non-null int64
xx_BPT_WHAN    83578 non-null float64
yy_BPT         83578 non-null float64
yy_WHAN        83578 non-null float64
BPT_class      83578 non-null object
WHAN_class     83578 non-null object
GMM_class_4    83578 non-null int64
pGC1           83578 non-null float64
pGC2           83578 non-null float64
pGC3           83578 non-null float64
pGC4           83578 non-null float64
dtypes: float64(7), int64(4), object(2)
memory usage: 8.3+ MB

The previous command displays information about our dataset. For instance, the plate line tells us that the feature is named "plate", that it contains 83578 values, none of which are null or missing, and that these are integer values.
To perform the GMM clustering we will use xx_BPT_WHAN, yy_BPT and yy_WHAN, i.e. $\log([\mathrm{NII}]/\mathrm{H}_{\alpha})$, $\log([\mathrm{OIII}]/\mathrm{H}_{\beta})$ and $\log(W_{\mathrm{H}_{\alpha}})$ respectively.
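Before fitting, it may be worth checking that these three columns form a finite numeric matrix: scikit-learn's `GaussianMixture` rejects NaN and infinite values, and the logarithm of a zero flux or equivalent width would produce `-inf`. The sketch below is only an illustration; the helper name `feature_matrix` is hypothetical, and the small DataFrame is a made-up stand-in for the catalogue.

```python
import numpy as np
import pandas as pd

FEATURES = ['xx_BPT_WHAN', 'yy_BPT', 'yy_WHAN']

def feature_matrix(df, features=FEATURES):
    """Return the clustering features as a float array, failing loudly on NaN/inf."""
    X = df[features].to_numpy(dtype=float)
    bad = ~np.isfinite(X).all(axis=1)  # rows with at least one NaN or +/-inf
    if bad.any():
        raise ValueError('{} rows contain NaN or infinite values'.format(bad.sum()))
    return X

# Hypothetical stand-in for the catalogue (values are made up)
toy = pd.DataFrame({'xx_BPT_WHAN': [-0.5, -0.2, 0.1],
                    'yy_BPT': [-0.3, 0.0, 0.4],
                    'yy_WHAN': [1.2, 0.8, 0.3]})
X = feature_matrix(toy)
print(X.shape)  # (3, 3)
```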


In [4]:
# A first look at the galaxy distribution in
# the BPT and WHAN diagrams
plt.figure(figsize=(16,6))

ax1 = plt.subplot(121)
data.plot.hexbin(x='xx_BPT_WHAN', y='yy_BPT', mincnt=1,
                 bins='log', gridsize=101, cmap='viridis', ax=ax1)

ax2 = plt.subplot(122)
data.plot.hexbin(x='xx_BPT_WHAN', y='yy_WHAN', mincnt=1,
                 bins='log', gridsize=101, cmap='viridis', ax=ax2)

set_plt_style()
plt.show()


GMM clustering

In our study we tried several classification methods with different parameters before focusing on the Gaussian Mixture Model (GMM) technique, fitted with 2, 3 and 4 clusters. Below I show how to perform such a clustering with GMM using scikit-learn (version 0.18.1).
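The cells below fit 2, 3 and 4 components without comparing them numerically. One standard way to do so in scikit-learn is the Bayesian Information Criterion, available as `GaussianMixture.bic`; this is a sketch of that approach and not necessarily the criterion used in the paper. It runs on synthetic 3D data drawn from three well-separated Gaussians (an assumption standing in for the catalogue), where the lowest BIC should select three components.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.RandomState(0)
# Synthetic 3D data from three well-separated Gaussians (stand-in for the catalogue)
centers = np.array([[-0.5, -0.5, 1.5], [0.0, 0.3, 0.5], [0.3, 0.8, 1.0]])
X = np.vstack([rng.normal(c, 0.1, size=(300, 3)) for c in centers])

bics = {}
for n in (2, 3, 4):
    gmm = GaussianMixture(n_components=n, covariance_type='full', random_state=0).fit(X)
    bics[n] = gmm.bic(X)  # lower BIC = better trade-off between fit quality and complexity

best_n = min(bics, key=bics.get)
print(bics, best_n)
```

On real data the BIC curve is often less clear-cut, so it is best read alongside physical interpretability of the groups.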

Some Notes

  • scikit-learn's GMM implementation does not always return the groups in the same order, so colors might not match between the different cases.
  • It is possible to ask the GMM directly for the group each sample belongs to. We prefer here to return, for each sample, the probability of belonging to each group, as this gives more flexibility in terms of cuts and potential future analyses.
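Regarding the first note, one possible workaround is to fix `random_state` and, more robustly, to reorder the components by a physical criterion after fitting. The sketch below sorts the probability columns by each component's mean along the first feature (here assumed to be xx_BPT_WHAN); the helper `ordered_proba` and the synthetic two-cluster data are our own illustrative assumptions.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

def ordered_proba(gmm, X, axis=0):
    """Return predict_proba columns sorted by the components' mean along `axis`."""
    order = np.argsort(gmm.means_[:, axis])
    return gmm.predict_proba(X)[:, order]

rng = np.random.RandomState(1)
# Synthetic data: first 200 rows around -1, last 200 rows around +1
X = np.vstack([rng.normal(-1.0, 0.2, size=(200, 3)),
               rng.normal(1.0, 0.2, size=(200, 3))])

# Two fits with different seeds may list the components in a different order...
p_a = ordered_proba(GaussianMixture(n_components=2, random_state=0).fit(X), X)
p_b = ordered_proba(GaussianMixture(n_components=2, random_state=5).fit(X), X)
# ...but after sorting by mean, column 0 is always the lowest-mean component
labels = p_a.argmax(axis=1)
```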

Notes on Visualization

For each case we visualize the results in three different ways:

  • A scatter plot with one color per group.
  • A figure showing the 68% and 95% confidence ellipses of the fitted Gaussians, projected onto each diagram.
  • A 2D kernel density estimate (KDE) per group (**WARNING**: this one can take several minutes).
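The ellipse helpers `plot_BPT_ell` and `plot_WHAN_ell` live in `gmm_fig_style.py`, which is not shown, so we do not know exactly how they scale the ellipses. One common construction, sketched below as an assumption, takes the 2×2 sub-block of a component's 3D covariance (the marginal of a Gaussian on two features is Gaussian with that sub-block), diagonalizes it, and scales the axes by the square root of the $\chi^2$ quantile with 2 degrees of freedom, so that the ellipse encloses the requested probability. The function name `confidence_ellipse` is hypothetical.

```python
import numpy as np
from scipy.stats import chi2

def confidence_ellipse(mean, cov, level=0.68, dims=(0, 1)):
    """Center, full width, full height and angle (degrees) of the `level`
    probability ellipse of a Gaussian projected onto the features in `dims`."""
    sub_cov = np.asarray(cov)[np.ix_(dims, dims)]  # marginal 2x2 covariance
    center = np.asarray(mean)[list(dims)]
    eigvals, eigvecs = np.linalg.eigh(sub_cov)     # eigenvalues in ascending order
    scale = np.sqrt(chi2.ppf(level, df=2))         # radius enclosing `level` of a 2D Gaussian
    width, height = 2 * scale * np.sqrt(eigvals[::-1])
    major = eigvecs[:, -1]                         # direction of the largest eigenvalue
    angle = np.degrees(np.arctan2(major[1], major[0]))
    return center, width, height, angle

# Example: BPT projection (features 0 and 1) of a made-up 3D component
center, width, height, angle = confidence_ellipse([1.0, 2.0, 3.0], np.diag([4.0, 1.0, 9.0]))
```

The returned values can be passed to `matplotlib.patches.Ellipse(center, width, height, angle=angle)`. If the notebook's helpers instead scale by 1σ and 2σ, those ellipses enclose roughly 39% and 86% of a 2D Gaussian rather than 68% and 95%.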

Two Clusters


In [5]:
# define and train the GMM
gmm2 = GaussianMixture(covariance_type='full', n_components=2)
gmm2.fit(data[['xx_BPT_WHAN','yy_BPT','yy_WHAN']])

# return the probability of belonging to a group
proba_gmm2 = gmm2.predict_proba(data[['xx_BPT_WHAN','yy_BPT','yy_WHAN']])
labels2 = proba_gmm2.argmax(axis=1)

# add to the initial dataframe new columns
# containing the probability to belong to a group
data['gmm2_proba1'] = proba_gmm2[:,0]
data['gmm2_proba2'] = proba_gmm2[:,1]
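The probabilities returned by `predict_proba` are the posterior membership probabilities of each component, i.e. Bayes' rule applied to the fitted `weights_`, `means_` and `covariances_`. The sketch below recomputes them by hand with SciPy on synthetic data (an assumption replacing the catalogue) and checks that taking the argmax, as done for `labels2`, reproduces `predict`.

```python
import numpy as np
from scipy.stats import multivariate_normal
from sklearn.mixture import GaussianMixture

rng = np.random.RandomState(2)
# Synthetic 3D data with two overlapping groups
X = np.vstack([rng.normal(0.0, 0.5, size=(150, 3)),
               rng.normal(1.5, 0.5, size=(150, 3))])
gmm = GaussianMixture(n_components=2, covariance_type='full', random_state=0).fit(X)

# Weighted likelihood of each sample under each component: pi_k * N(x | mu_k, Sigma_k)
weighted = np.column_stack([
    w * multivariate_normal(mean=m, cov=c).pdf(X)
    for w, m, c in zip(gmm.weights_, gmm.means_, gmm.covariances_)
])
# Bayes' rule: normalize over components to get membership probabilities
manual_proba = weighted / weighted.sum(axis=1, keepdims=True)
manual_labels = manual_proba.argmax(axis=1)
```

scikit-learn performs the same computation in log space, which is numerically safer for points far from every component.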

In [6]:
# visualizing the resulting classification
plt.figure(figsize=(14,6))

plt.subplot(121)
plt.scatter(data.xx_BPT_WHAN, data.yy_BPT, c=labels2, s=0.4, cmap=my_cmap)

plt.subplot(122)
plt.scatter(data.xx_BPT_WHAN, data.yy_WHAN, c=labels2, s=0.4, cmap=my_cmap)

set_plt_style()
plt.show()



In [7]:
plt.figure(figsize=(14,6))

ax1 = plt.subplot(121)
for k, col in enumerate(['C0','C3']):
    plot_BPT_ell(gmm2.covariances_[k], gmm2.means_[k], ax=ax1, col=col)

ax2 = plt.subplot(122)
for k, col in enumerate(['C0','C3']):
    plot_WHAN_ell(gmm2.covariances_[k], gmm2.means_[k], ax=ax2, col=col)

set_plt_style()
plt.show()


Seaborn can visualize the density distribution with a kernel density estimate (KDE). In the following plot we only keep points whose probability of belonging to a group is greater than 50%.
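The same idea can be sketched directly with SciPy: select the points whose membership probability exceeds 0.5, fit a 2D Gaussian KDE with Scott's bandwidth rule (the `bw='scott'` choice used in the cells), and evaluate it on a grid suitable for `plt.contour`. The arrays below are synthetic stand-ins for `xx_BPT_WHAN`, `yy_BPT` and `gmm2_proba1`; seaborn's own implementation may differ in details, so the contours are not guaranteed to match exactly.

```python
import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.RandomState(3)
x = rng.normal(-0.4, 0.2, size=500)           # stand-in for xx_BPT_WHAN
y = 0.5 * x + rng.normal(0.0, 0.1, size=500)  # stand-in for yy_BPT
proba = rng.uniform(size=500)                 # stand-in for gmm2_proba1

mask = proba > 0.5                            # keep likely members only
kde = gaussian_kde(np.vstack([x[mask], y[mask]]), bw_method='scott')

# Evaluate the density on a regular grid, e.g. for plt.contour(gx, gy, density)
gx, gy = np.meshgrid(np.linspace(-1.2, 0.4, 80), np.linspace(-1.0, 0.6, 80))
density = kde(np.vstack([gx.ravel(), gy.ravel()])).reshape(gx.shape)
```

Scott's rule sets the bandwidth factor to $n^{-1/(d+4)}$, here $n^{-1/6}$ for $d = 2$ dimensions, so groups with more members get narrower kernels.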


In [8]:
plt.figure(figsize=(14,6))

ax1 = plt.subplot(121)
sns.kdeplot(data[data.gmm2_proba1>0.5].xx_BPT_WHAN, data[data.gmm2_proba1>0.5].yy_BPT,
            bw='scott', n_levels=5, ax=ax1, cmap='Greens_d')
sns.kdeplot(data[data.gmm2_proba2>0.5].xx_BPT_WHAN, data[data.gmm2_proba2>0.5].yy_BPT,
            bw='scott', n_levels=5, ax=ax1, cmap='RdPu_d')

ax2 = plt.subplot(122)
sns.kdeplot(data[data.gmm2_proba1>0.5].xx_BPT_WHAN, data[data.gmm2_proba1>0.5].yy_WHAN,
            bw='scott', n_levels=5, ax=ax2, cmap='Greens_d')
sns.kdeplot(data[data.gmm2_proba2>0.5].xx_BPT_WHAN, data[data.gmm2_proba2>0.5].yy_WHAN,
            bw='scott', n_levels=5, ax=ax2, cmap='RdPu_d')

set_plt_style()
plt.show()


Three Clusters


In [9]:
# define and train GMM
gmm3 = GaussianMixture(covariance_type='full', n_components=3)
gmm3.fit(data[['xx_BPT_WHAN','yy_BPT','yy_WHAN']])

# return the probability of belonging to a group
proba_gmm3 = gmm3.predict_proba(data[['xx_BPT_WHAN','yy_BPT','yy_WHAN']])
labels3 = proba_gmm3.argmax(axis=1)

# add to the initial dataframe new columns
# containing the probability to belong to a group
data['gmm3_proba1'] = proba_gmm3[:,0]
data['gmm3_proba2'] = proba_gmm3[:,1]
data['gmm3_proba3'] = proba_gmm3[:,2]

In [10]:
plt.figure(figsize=(14,6))

plt.subplot(121)
plt.scatter(data.xx_BPT_WHAN, data.yy_BPT, c=labels3, s=0.4, cmap=my_cmap)

plt.subplot(122)
plt.scatter(data.xx_BPT_WHAN, data.yy_WHAN, c=labels3, s=0.4, cmap=my_cmap)

set_plt_style()
plt.show()



In [11]:
plt.figure(figsize=(14,6))

ax1 = plt.subplot(121)
for k, col in enumerate(['C0','C2','C3']):
    plot_BPT_ell(gmm3.covariances_[k], gmm3.means_[k], ax=ax1, col=col)

ax2 = plt.subplot(122)
for k, col in enumerate(['C0','C2','C3']):
    plot_WHAN_ell(gmm3.covariances_[k], gmm3.means_[k], ax=ax2, col=col)

set_plt_style()
plt.show()



In [12]:
plt.figure(figsize=(14,6))

ax1 = plt.subplot(121)
sns.kdeplot(data[data.gmm3_proba1>0.5].xx_BPT_WHAN, data[data.gmm3_proba1>0.5].yy_BPT,
            bw='scott', n_levels=5, ax=ax1, cmap='Greens_d')
sns.kdeplot(data[data.gmm3_proba2>0.5].xx_BPT_WHAN, data[data.gmm3_proba2>0.5].yy_BPT,
            bw='scott', n_levels=5, ax=ax1, cmap='Blues_d')
sns.kdeplot(data[data.gmm3_proba3>0.5].xx_BPT_WHAN, data[data.gmm3_proba3>0.5].yy_BPT,
            bw='scott', n_levels=5, ax=ax1, cmap='RdPu_d')

ax2 = plt.subplot(122)
sns.kdeplot(data[data.gmm3_proba1>0.5].xx_BPT_WHAN, data[data.gmm3_proba1>0.5].yy_WHAN,
            bw='scott', n_levels=5, ax=ax2, cmap='Greens_d')
sns.kdeplot(data[data.gmm3_proba2>0.5].xx_BPT_WHAN, data[data.gmm3_proba2>0.5].yy_WHAN,
            bw='scott', n_levels=5, ax=ax2, cmap='Blues_d')
sns.kdeplot(data[data.gmm3_proba3>0.5].xx_BPT_WHAN, data[data.gmm3_proba3>0.5].yy_WHAN,
            bw='scott', n_levels=5, ax=ax2, cmap='RdPu_d')

set_plt_style()
plt.show()


Four Clusters


In [13]:
# define and train GMM
gmm4 = GaussianMixture(covariance_type='full', n_components=4)
gmm4.fit(data[['xx_BPT_WHAN','yy_BPT','yy_WHAN']])

# return the probability of belonging to a group
proba_gmm4 = gmm4.predict_proba(data[['xx_BPT_WHAN','yy_BPT','yy_WHAN']])
labels4 = proba_gmm4.argmax(axis=1)

# add to the initial dataframe new columns
# containing the probability to belong to a group
data['gmm4_proba1'] = proba_gmm4[:,0]
data['gmm4_proba2'] = proba_gmm4[:,1]
data['gmm4_proba3'] = proba_gmm4[:,2]
data['gmm4_proba4'] = proba_gmm4[:,3]
data['GMM4_group'] = labels4
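The three fitting cells above repeat the same steps for 2, 3 and 4 components. A possible refactoring into one function is sketched below; the name `add_gmm_probas` and its `random_state` parameter are our own additions, and it writes the same 1-based `gmm{n}_proba{k}` column names into the DataFrame in place. The usage example runs on random toy data rather than the catalogue.

```python
import numpy as np
import pandas as pd
from sklearn.mixture import GaussianMixture

def add_gmm_probas(df, n_components, features=('xx_BPT_WHAN', 'yy_BPT', 'yy_WHAN'),
                   random_state=None):
    """Fit a full-covariance GMM, add columns gmm{n}_proba1..n to `df` (in place)
    and return the fitted model with the most probable group of each row."""
    X = df[list(features)]
    gmm = GaussianMixture(covariance_type='full', n_components=n_components,
                          random_state=random_state).fit(X)
    proba = gmm.predict_proba(X)
    for k in range(n_components):
        df['gmm{}_proba{}'.format(n_components, k + 1)] = proba[:, k]
    return gmm, proba.argmax(axis=1)

# Usage on toy data; with the catalogue this would be e.g. add_gmm_probas(data, 4)
rng = np.random.RandomState(4)
toy = pd.DataFrame(rng.normal(size=(100, 3)), columns=['xx_BPT_WHAN', 'yy_BPT', 'yy_WHAN'])
gmm_toy, labels_toy = add_gmm_probas(toy, 3, random_state=0)
```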

In [14]:
plt.figure(figsize=(14,6))

plt.subplot(121)
plt.scatter(data.xx_BPT_WHAN, data.yy_BPT, c=labels4, s=0.4, cmap=my_cmap)

plt.subplot(122)
plt.scatter(data.xx_BPT_WHAN, data.yy_WHAN, c=labels4, s=0.4, cmap=my_cmap)

set_plt_style()
plt.show()



In [15]:
plt.figure(figsize=(14,6))

ax1 = plt.subplot(121)
for k, col in enumerate(['C0','C1','C2','C3']):
    plot_BPT_ell(gmm4.covariances_[k], gmm4.means_[k], ax=ax1, col=col)

ax2 = plt.subplot(122)
for k, col in enumerate(['C0','C1','C2','C3']):
    plot_WHAN_ell(gmm4.covariances_[k], gmm4.means_[k], ax=ax2, col=col)

set_plt_style()
plt.show()



In [16]:
plt.figure(figsize=(14,6))

ax1 = plt.subplot(121)
sns.kdeplot(data[data.gmm4_proba1>0.5].xx_BPT_WHAN, data[data.gmm4_proba1>0.5].yy_BPT,
            bw='scott', n_levels=5, ax=ax1, cmap='Greens_d')
sns.kdeplot(data[data.gmm4_proba2>0.5].xx_BPT_WHAN, data[data.gmm4_proba2>0.5].yy_BPT,
            bw='scott', n_levels=5, ax=ax1, cmap='Oranges_d')
sns.kdeplot(data[data.gmm4_proba3>0.5].xx_BPT_WHAN, data[data.gmm4_proba3>0.5].yy_BPT,
            bw='scott', n_levels=5, ax=ax1, cmap='Blues_d')
sns.kdeplot(data[data.gmm4_proba4>0.5].xx_BPT_WHAN, data[data.gmm4_proba4>0.5].yy_BPT,
            bw='scott', n_levels=5, ax=ax1, cmap='RdPu_d')

ax2 = plt.subplot(122)
sns.kdeplot(data[data.gmm4_proba1>0.5].xx_BPT_WHAN, data[data.gmm4_proba1>0.5].yy_WHAN,
            bw='scott', n_levels=5, ax=ax2, cmap='Greens_d')
sns.kdeplot(data[data.gmm4_proba2>0.5].xx_BPT_WHAN, data[data.gmm4_proba2>0.5].yy_WHAN,
            bw='scott', n_levels=5, ax=ax2, cmap='Oranges_d')
sns.kdeplot(data[data.gmm4_proba3>0.5].xx_BPT_WHAN, data[data.gmm4_proba3>0.5].yy_WHAN,
            bw='scott', n_levels=5, ax=ax2, cmap='Blues_d')
sns.kdeplot(data[data.gmm4_proba4>0.5].xx_BPT_WHAN, data[data.gmm4_proba4>0.5].yy_WHAN,
            bw='scott', n_levels=5, ax=ax2, cmap='RdPu_d')

set_plt_style()
plt.show()


LDA Analysis

In this section we perform a Linear Discriminant Analysis (LDA) as an external cluster validation technique (see Sec. 5 of the paper for more details). Hereafter we only present the case with 4 clusters.
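Each comparison below involves exactly two groups (one reference BPT or WHAN class and one GMM group), and with two classes LDA has at most $n_\mathrm{classes} - 1 = 1$ discriminant axis, so each projection is a single column. The sketch below illustrates this on two synthetic 2D groups standing in for a BPT class and a GMM group; the variable name `lda_demo` is chosen so as not to overwrite the notebook's `lda`.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

rng = np.random.RandomState(5)
# Two synthetic 2D groups, standing in for a BPT class and a GMM group
A = rng.normal([-0.5, -0.3], 0.15, size=(300, 2))
B = rng.normal([-0.1, 0.2], 0.15, size=(300, 2))
X = np.vstack([A, B])
y = np.array(['SF'] * 300 + ['gc0'] * 300)

# With two classes LDA yields at most one discriminant axis
lda_demo = LinearDiscriminantAnalysis(n_components=1).fit(X, y)
a_proj = lda_demo.transform(A).ravel()  # 1D projection of each group
b_proj = lda_demo.transform(B).ravel()
```

If the two 1D distributions overlap strongly, the GMM group and the reference class occupy similar regions of the diagram; well-separated distributions indicate they differ.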


In [17]:
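from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# Each fit compares two classes (one reference class vs. one GMM group), so LDA yields
# at most n_classes - 1 = 1 axis; newer scikit-learn versions raise an error for n_components=2
lda = LinearDiscriminantAnalysis(n_components=1)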

In [18]:
# defining some useful functions
def _lda_validation(class_col, y_col, ref_grp, gmm_grp):
    """Project a reference class and a GMM group onto their common 1D LDA axis."""
    features = ['xx_BPT_WHAN', y_col]

    # .assign returns a new DataFrame, which avoids pandas' SettingWithCopyWarning
    ref = data.loc[data[class_col] == ref_grp, features].assign(type=ref_grp)
    gmm_gc = data.loc[data.GMM4_group == gmm_grp, features].assign(type='gc{}'.format(gmm_grp))

    tmp = pd.concat([ref, gmm_gc])
    lda.fit(tmp[features], tmp['type'])

    ref_lda = lda.transform(ref[features])
    gc_lda = lda.transform(gmm_gc[features])

    return ref_lda.ravel(), gc_lda.ravel()

def lda_BPT_validation(bpt_grp, gmm_grp):
    # compare a BPT class with a GMM group in the BPT plane
    return _lda_validation('BPT_class', 'yy_BPT', bpt_grp, gmm_grp)

def lda_WHAN_validation(whan_grp, gmm_grp):
    # compare a WHAN class with a GMM group in the WHAN plane
    return _lda_validation('WHAN_class', 'yy_WHAN', whan_grp, gmm_grp)

Comparison with the BPT classification
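Alongside the LDA projections, a contingency table gives a coarser first comparison: for each BPT class, the fraction of its galaxies assigned to each GMM group. The sketch below uses `pd.crosstab` on six made-up galaxies; with the catalogue, the same call would take `data['BPT_class']` and `data['GMM4_group']`.

```python
import pandas as pd

# Hypothetical labels for six galaxies (not taken from the catalogue)
toy = pd.DataFrame({'BPT_class': ['SF', 'SF', 'Composite', 'AGN', 'AGN', 'SF'],
                    'GMM4_group': [0, 0, 1, 2, 2, 1]})

# Rows: BPT class; columns: GMM group; values: fraction of each BPT class per group
table = pd.crosstab(toy['BPT_class'], toy['GMM4_group'], normalize='index')
print(table)
```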


In [19]:
plt.figure(figsize=(14,12))

for i, (bpt, (k, col)) in enumerate(
        itertools.product(['SF','Composite','AGN'], enumerate(['C0','C1','C2','C3'])), start=1):
    bpt_lda, gc_lda = lda_BPT_validation(bpt, k)
    plt.subplot(3, 4, i)
    sns.kdeplot(bpt_lda, shade=True, color='k')
    sns.kdeplot(gc_lda, shade=True, color=col)

lda_bpt_plt_style()


Comparison with the WHAN classification


In [20]:
plt.figure(figsize=(14,16))

for i, (whan, (k, col)) in enumerate(
        itertools.product(['SF','sAGN','wAGN','retired'], enumerate(['C0','C1','C2','C3'])), start=1):
    whan_lda, gc_lda = lda_WHAN_validation(whan, k)
    plt.subplot(4, 4, i)
    sns.kdeplot(whan_lda, shade=True, color='k')
    sns.kdeplot(gc_lda, shade=True, color=col)

lda_whan_plt_style()


